import torch

from src.gfn.gfn import GFlowNet, LossTensor
from src.utils.trajectories import Trajectories

class TBGFlowNet(GFlowNet):
    """GFlowNet trained with the trajectory balance (TB) objective.

    Learnable components:
      * ``logZ`` -- a scalar ``torch.nn.Parameter`` initialised to 1.0,
        used as the log-partition-function term of the TB loss;
      * the forward policy (``forward_model``);
      * the backward policy (``backward_model``).

    By default (``tied=False``) the forward and backward policies are
    intended not to share parameters.
    """

    def __init__(
        self,
        env,
        config,
        forward_model,
        backward_model,
        tied: bool = False,
    ):
        """Build the model, its optimizer and its learning-rate scheduler.

        Args:
            env: Environment; forwarded unchanged to ``GFlowNet.__init__``.
            config: Configuration mapping. It is forwarded to the base
                class, and ``config["gfn"]["lr_schedule"]`` is handed to
                ``self._init_scheduler``.
            forward_model: Forward-policy model; forwarded to the base class.
            backward_model: Backward-policy model; forwarded to the base class.
            tied: Passed as the first argument of ``self._init_optimizer``.
                Presumably this toggles parameter sharing between the two
                policies. TODO: confirm in ``GFlowNet._init_optimizer``.
        """
        super().__init__(env, config, forward_model, backward_model)

        # logZ must exist before the optimizer is built, because the
        # optimizer is explicitly asked to include it (include_logZ=True).
        initial_log_z = torch.tensor(1.0, device=self.device)
        self.logZ = torch.nn.Parameter(initial_log_z)

        self.optimizer = self._init_optimizer(tied, include_logZ=True)
        schedule_cfg = config["gfn"]["lr_schedule"]
        self.scheduler = self._init_scheduler(schedule_cfg)

    def _compute_loss_precursors(self, trajs: Trajectories, head=None):
        """Ask ``trajs`` to compute the log-probabilities the TB loss reads.

        Calls ``trajs.compute_logPF(self, head)`` and then
        ``trajs.compute_logPB(self)``. Returns nothing. Presumably these
        calls populate ``trajs.logPF`` / ``trajs.logPB`` in place, which
        ``loss`` reads afterwards. TODO: confirm in ``Trajectories``.
        """
        trajs.compute_logPF(self, head)
        trajs.compute_logPB(self)

    def loss(self, trajs: Trajectories, head=None) -> LossTensor:
        """Return the trajectory balance loss averaged over ``trajs``.

        For every trajectory the residual is::

            logZ + logPF - logPB - clip(log_reward, min=log_reward_clip_min)

        The returned loss is the mean of the squared residuals.

        Args:
            trajs: Batch of trajectories. Its log-probabilities are
                (re)computed here before being read.
            head: Forwarded to ``trajs.compute_logPF``.

        Raises:
            ValueError: If the resulting loss is NaN.
        """
        self._compute_loss_precursors(trajs, head)

        # Log-rewards are clipped from below at self.log_reward_clip_min.
        # This presumably keeps zero / tiny rewards (log -> -inf or very
        # negative) from blowing up the squared residual.
        residual = (
            self.logZ
            + trajs.logPF
            - trajs.logPB
            - trajs.log_rewards.clip(min=self.log_reward_clip_min)
        )
        loss = (residual ** 2).mean()

        if loss.isnan():
            raise ValueError("Loss is NaN.")
        return loss
    